{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp parallel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from fastcore.imports import *\n",
    "from fastcore.basics import *\n",
    "from fastcore.foundation import *\n",
    "from fastcore.meta import *\n",
    "from fastcore.xtras import *\n",
    "from functools import wraps\n",
    "\n",
    "import concurrent.futures,time\n",
    "from multiprocessing import Process,Queue,Manager,set_start_method,get_all_start_methods,get_context\n",
    "from threading import Thread\n",
    "try:\n",
    "    if sys.platform == 'darwin' and IN_NOTEBOOK: set_start_method(\"fork\")\n",
    "except: pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastcore.test import *\n",
    "from nbdev.showdoc import *\n",
    "from fastcore.nb_imports import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Parallel\n",
    "\n",
    "> Threading and multiprocessing functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def threaded(process=False):\n",
    "    \"Run `f` in a `Thread` (or `Process` if `process=True`), and returns it\"\n",
    "    def _r(f):\n",
    "        def g(_obj_td, *args, **kwargs):\n",
    "            res = f(*args, **kwargs)\n",
    "            _obj_td.result = res\n",
    "        @wraps(f)\n",
    "        def _f(*args, **kwargs):\n",
    "            res = (Thread,Process)[process](target=g, args=args, kwargs=kwargs)\n",
    "            res._args = (res,)+res._args\n",
    "            res.start()\n",
    "            return res\n",
    "        return _f\n",
    "    if callable(process):\n",
    "        o = process\n",
    "        process = False\n",
    "        return _r(o)\n",
    "    return _r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "first\n",
      "second\n"
     ]
    }
   ],
   "source": [
    "@threaded\n",
    "def _1():\n",
    "    time.sleep(0.05)\n",
    "    print(\"second\")\n",
    "    return 5\n",
    "\n",
    "@threaded\n",
    "def _2():\n",
    "    time.sleep(0.01)\n",
    "    print(\"first\")\n",
    "\n",
    "a = _1()\n",
    "_2()\n",
    "time.sleep(0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After the thread is complete, the return value is stored in the `result` attr."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#| eval: false\n",
    "a.result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def startthread(f):\n",
    "    \"Like `threaded`, but start thread immediately\"\n",
    "    return threaded(f)()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "first\n",
      "second\n"
     ]
    }
   ],
   "source": [
    "@startthread\n",
    "def _():\n",
    "    time.sleep(0.05)\n",
    "    print(\"second\")\n",
    "\n",
    "@startthread\n",
    "def _():\n",
    "    time.sleep(0.01)\n",
    "    print(\"first\")\n",
    "\n",
    "time.sleep(0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def startproc(f):\n",
    "    \"Like `threaded(True)`, but start Process immediately\"\n",
    "    return threaded(True)(f)()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "first\n",
      "second\n"
     ]
    }
   ],
   "source": [
    "@startproc\n",
    "def _():\n",
    "    time.sleep(0.05)\n",
    "    print(\"second\")\n",
    "\n",
    "@startproc\n",
    "def _():\n",
    "    time.sleep(0.01)\n",
    "    print(\"first\")\n",
    "\n",
    "time.sleep(0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def _call(lock, pause, n, g, item):\n",
    "    l = False\n",
    "    if pause:\n",
    "        try:\n",
    "            l = lock.acquire(timeout=pause*(n+2))\n",
    "            time.sleep(pause)\n",
    "        finally:\n",
    "            if l: lock.release()\n",
    "    return g(item)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def parallelable(param_name, num_workers, f=None):\n",
    "    f_in_main = f == None or sys.modules[f.__module__].__name__ == \"__main__\"\n",
    "    if sys.platform == \"win32\" and IN_NOTEBOOK and num_workers > 0 and f_in_main:\n",
    "        print(\"Due to IPython and Windows limitation, python multiprocessing isn't available now.\")\n",
    "        print(f\"So `{param_name}` has to be changed to 0 to avoid getting stuck\")\n",
    "        return False\n",
    "    return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class ThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):\n",
    "    \"Same as Python's ThreadPoolExecutor, except can pass `max_workers==0` for serial execution\"\n",
    "    def __init__(self, max_workers=defaults.cpus, on_exc=print, pause=0, **kwargs):\n",
    "        if max_workers is None: max_workers=defaults.cpus\n",
    "        store_attr()\n",
    "        self.not_parallel = max_workers==0\n",
    "        if self.not_parallel: max_workers=1\n",
    "        super().__init__(max_workers, **kwargs)\n",
    "\n",
    "    def map(self, f, items, *args, timeout=None, chunksize=1, **kwargs):\n",
    "        if self.not_parallel == False: self.lock = Manager().Lock()\n",
    "        g = partial(f, *args, **kwargs)\n",
    "        if self.not_parallel: return map(g, items)\n",
    "        _g = partial(_call, self.lock, self.pause, self.max_workers, g)\n",
    "        try: return super().map(_g, items, timeout=timeout, chunksize=chunksize)\n",
    "        except Exception as e: self.on_exc(e)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "### ThreadPoolExecutor\n",
       "\n",
       ">      ThreadPoolExecutor (max_workers=8, on_exc=<built-infunctionprint>,\n",
       ">                          pause=0, **kwargs)\n",
       "\n",
       "Same as Python's ThreadPoolExecutor, except can pass `max_workers==0` for serial execution"
      ],
      "text/plain": [
       "<nbdev.showdoc.BasicMarkdownRenderer>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(ThreadPoolExecutor, title_level=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "@delegates()\n",
    "class ProcessPoolExecutor(concurrent.futures.ProcessPoolExecutor):\n",
    "    \"Same as Python's ProcessPoolExecutor, except can pass `max_workers==0` for serial execution\"\n",
    "    def __init__(self, max_workers=defaults.cpus, on_exc=print, pause=0, **kwargs):\n",
    "        if max_workers is None: max_workers=defaults.cpus\n",
    "        store_attr()\n",
    "        self.not_parallel = max_workers==0\n",
    "        if self.not_parallel: max_workers=1\n",
    "        super().__init__(max_workers, **kwargs)\n",
    "\n",
    "    def map(self, f, items, *args, timeout=None, chunksize=1, **kwargs):\n",
    "        if not parallelable('max_workers', self.max_workers, f): self.max_workers = 0\n",
    "        self.not_parallel = self.max_workers==0\n",
    "        if self.not_parallel: self.max_workers=1\n",
    "\n",
    "        if self.not_parallel == False: self.lock = Manager().Lock()\n",
    "        g = partial(f, *args, **kwargs)\n",
    "        if self.not_parallel: return map(g, items)\n",
    "        _g = partial(_call, self.lock, self.pause, self.max_workers, g)\n",
    "        try: return super().map(_g, items, timeout=timeout, chunksize=chunksize)\n",
    "        except Exception as e: self.on_exc(e)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "### ProcessPoolExecutor\n",
       "\n",
       ">      ProcessPoolExecutor (max_workers=8, on_exc=<built-infunctionprint>,\n",
       ">                           pause=0, mp_context=None, initializer=None,\n",
       ">                           initargs=())\n",
       "\n",
       "Same as Python's ProcessPoolExecutor, except can pass `max_workers==0` for serial execution"
      ],
      "text/plain": [
       "<nbdev.showdoc.BasicMarkdownRenderer>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(ProcessPoolExecutor, title_level=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "try: from fastprogress import progress_bar\n",
    "except: progress_bar = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def parallel(f, items, *args, n_workers=defaults.cpus, total=None, progress=None, pause=0,\n",
    "             method=None, threadpool=False, timeout=None, chunksize=1, **kwargs):\n",
    "    \"Applies `func` in parallel to `items`, using `n_workers`\"\n",
    "    kwpool = {}\n",
    "    if threadpool: pool = ThreadPoolExecutor\n",
    "    else:\n",
    "        if not method and sys.platform == 'darwin': method='fork'\n",
    "        if method: kwpool['mp_context'] = get_context(method)\n",
    "        pool = ProcessPoolExecutor\n",
    "    with pool(n_workers, pause=pause, **kwpool) as ex:\n",
    "        r = ex.map(f,items, *args, timeout=timeout, chunksize=chunksize, **kwargs)\n",
    "        if progress and progress_bar:\n",
    "            if total is None: total = len(items)\n",
    "            r = progress_bar(r, total=total, leave=False)\n",
    "        return L(r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def add_one(x, a=1):\n",
    "    # this import is necessary for multiprocessing in notebook on windows\n",
    "    import random\n",
    "    time.sleep(random.random()/80)\n",
    "    return x+a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inp,exp = range(50),range(1,51)\n",
    "\n",
    "test_eq(parallel(add_one, inp, n_workers=2, progress=False), exp)\n",
    "test_eq(parallel(add_one, inp, threadpool=True, n_workers=2, progress=False), exp)\n",
    "test_eq(parallel(add_one, inp, n_workers=1, a=2), range(2,52))\n",
    "test_eq(parallel(add_one, inp, n_workers=0), exp)\n",
    "test_eq(parallel(add_one, inp, n_workers=0, a=2), range(2,52))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use the `pause` parameter to ensure a pause of `pause` seconds between processes starting. This is in case there are race conditions in starting some process, or to stagger the time each process starts, for example when making many requests to a webserver. Set `threadpool=True` to use `ThreadPoolExecutor` instead of `ProcessPoolExecutor`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 2022-08-07 05:10:05.999916\n",
      "1 2022-08-07 05:10:06.252031\n",
      "2 2022-08-07 05:10:06.503603\n",
      "3 2022-08-07 05:10:06.755216\n",
      "4 2022-08-07 05:10:07.006702\n"
     ]
    }
   ],
   "source": [
    "def print_time(i): \n",
    "    time.sleep(random.random()/1000)\n",
    "    print(i, datetime.now())\n",
    "\n",
    "parallel(print_time, range(5), n_workers=2, pause=0.25);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(#8) [0,2,4,6,8,10,12,14]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#|hide\n",
    "def die_sometimes(x):\n",
    "#     if 3<x<6: raise Exception(f\"exc: {x}\")\n",
    "    return x*2\n",
    "\n",
    "parallel(die_sometimes, range(8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def run_procs(f, f_done, args):\n",
    "    \"Call `f` for each item in `args` in parallel, yielding `f_done`\"\n",
    "    processes = L(args).map(Process, args=arg0, target=f)\n",
    "    for o in processes: o.start()\n",
    "    yield from f_done()\n",
    "    processes.map(Self.join())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def _f_pg(obj, queue, batch, start_idx):\n",
    "    for i,b in enumerate(obj(batch)): queue.put((start_idx+i,b))\n",
    "\n",
    "def _done_pg(queue, items): return (queue.get() for _ in items)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export \n",
    "def parallel_gen(cls, items, n_workers=defaults.cpus, **kwargs):\n",
    "    \"Instantiate `cls` in `n_workers` procs & call each on a subset of `items` in parallel.\"\n",
    "    if not parallelable('n_workers', n_workers): n_workers = 0\n",
    "    if n_workers==0:\n",
    "        yield from enumerate(list(cls(**kwargs)(items)))\n",
    "        return\n",
    "    batches = L(chunked(items, n_chunks=n_workers))\n",
    "    idx = L(itertools.accumulate(0 + batches.map(len)))\n",
    "    queue = Queue()\n",
    "    if progress_bar: items = progress_bar(items, leave=False)\n",
    "    f=partial(_f_pg, cls(**kwargs), queue)\n",
    "    done=partial(_done_pg, queue, items)\n",
    "    yield from run_procs(f, done, L(batches,idx).zip())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# class _C:\n",
    "#     def __call__(self, o): return ((i+1) for i in o)\n",
    "\n",
    "# items = range(5)\n",
    "\n",
    "# res = L(parallel_gen(_C, items, n_workers=0))\n",
    "# idxs,dat1 = zip(*res.sorted(itemgetter(0)))\n",
    "# test_eq(dat1, range(1,6))\n",
    "\n",
    "# res = L(parallel_gen(_C, items, n_workers=3))\n",
    "# idxs,dat2 = zip(*res.sorted(itemgetter(0)))\n",
    "# test_eq(dat2, dat1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`cls` is any class with `__call__`. It will be passed `args` and `kwargs` when initialized. Note that `n_workers` instances of `cls` are created, one in each process. `items` are then split in `n_workers` batches and one is sent to each `cls`. The function then returns a generator of tuples of item indices and results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
       "        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "class TestSleepyBatchFunc:\n",
    "    \"For testing parallel processes that run at different speeds\"\n",
    "    def __init__(self): self.a=1\n",
    "    def __call__(self, batch):\n",
    "        for k in batch:\n",
    "            time.sleep(random.random()/4)\n",
    "            yield k+self.a\n",
    "\n",
    "x = np.linspace(0,0.99,20)\n",
    "\n",
    "res = L(parallel_gen(TestSleepyBatchFunc, x, n_workers=2))\n",
    "test_eq(res.sorted().itemgot(1), x+1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# #|hide\n",
    "# from subprocess import Popen, PIPE\n",
    "# # test num_workers > 0 in scripts works when python process start method is spawn\n",
    "# process = Popen([\"python\", \"parallel_test.py\"], stdout=PIPE)\n",
    "# _, err = process.communicate(timeout=10)\n",
    "# exit_code = process.wait()\n",
    "# test_eq(exit_code, 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "import nbdev; nbdev.nbdev_export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
